[94d9b6]: / Experimentations / Exp14-Label smoothing and noisy ground truth.ipynb

Download this file

520 lines (519 with data), 31.9 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 14. Experimenting with label smoothing and CEL to deal with noisy ground truth\n",
    "We tried training a network with label smoothing, a technique generally used when the ground truth is noisy or involves a lot of subjectivity. Label smoothing is commonly applied to classification problems but, as far as we know, not to segmentation problems, so we tried it for segmentation. It did not work well: the Dice scores were nowhere close to acceptable levels, so training was stopped midway. As future work, we would like to engineer a way to segment images based on noisy ground truth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torchvision import datasets, models\n",
    "from torchvision import transforms as T\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import time\n",
    "import pandas as pd\n",
    "from skimage import io, transform\n",
    "import matplotlib.image as mpimg\n",
    "from PIL import Image\n",
    "from sklearn.metrics import roc_auc_score\n",
    "import torch.nn.functional as F\n",
    "import scipy\n",
    "import random\n",
    "import pickle\n",
    "import scipy.io as sio\n",
    "import itertools\n",
    "from scipy.ndimage.interpolation import shift\n",
    "import copy\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "%matplotlib inline\n",
    "plt.ion()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import Dataloader Class and other utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from dataloader_2d import *\n",
    "from dataloader_3d import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build Data loader objects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_path = '/beegfs/ark576/new_knee_data/train'\n",
    "val_path = '/beegfs/ark576/new_knee_data/val'\n",
    "test_path = '/beegfs/ark576/new_knee_data/test'\n",
    "\n",
    "train_file_names = sorted(pickle.load(open(train_path + '/train_file_names.p','rb')))\n",
    "val_file_names = sorted(pickle.load(open(val_path + '/val_file_names.p','rb')))\n",
    "test_file_names = sorted(pickle.load(open(test_path + '/test_file_names.p','rb')))\n",
    "\n",
    "transformed_dataset = {'train': KneeMRIDataset(train_path,train_file_names, train_data= True, flipping=False, normalize= True),\n",
    "                       'validate': KneeMRIDataset(val_path,val_file_names, normalize= True),\n",
    "                       'test': KneeMRIDataset(test_path,test_file_names, normalize= True)\n",
    "                                          }\n",
    "\n",
    "dataloader = {x: DataLoader(transformed_dataset[x], batch_size=5,\n",
    "                        shuffle=True, num_workers=0) for x in ['train', 'validate','test']}\n",
    "data_sizes ={x: len(transformed_dataset[x]) for x in ['train', 'validate','test']}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "im, seg_F, seg_P, seg_T,_ = next(iter(dataloader['train']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Find Max and min values of Images (all 7 contrasts), of Fractional Anisotropy maps and of Mean Diffusivity maps for image normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "min_fa = np.inf\n",
    "min_md = np.inf\n",
    "min_image = np.inf\n",
    "max_fa = 0\n",
    "max_md = 0\n",
    "max_image = 0\n",
    "for data in dataloader['train']:\n",
    "    if min_fa > torch.min(data[0][:,7,:,:]):\n",
    "        min_fa = torch.min(data[0][:,7,:,:])\n",
    "    if min_md > torch.min(data[0][:,8,:,:]):\n",
    "        min_md = torch.min(data[0][:,8:,:])\n",
    "    if min_image > torch.min(data[0][:,:7,:,:]):\n",
    "        min_image = torch.min(data[0][:,:7,:,:])\n",
    "    if max_fa < torch.max(data[0][:,7,:,:]):\n",
    "        max_fa = torch.max(data[0][:,7,:,:])\n",
    "    if max_md < torch.max(data[0][:,8,:,:]):\n",
    "        max_md = torch.max(data[0][:,8,:,:])\n",
    "    if max_image < torch.max(data[0][:,:7,:,:]):\n",
    "        max_image = torch.max(data[0][:,:7,:,:])\n",
    "norm_values = (max_image, min_image, max_fa, min_fa, max_md, min_md)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from unet_3d import *\n",
    "from unet_basic_dilated import *\n",
    "from vnet import *\n",
    "from ensemble_model import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "seg_sum = torch.zeros(3)\n",
    "for i, data in enumerate(dataloader['train']):\n",
    "    input, segF, segP, segT,_ = data\n",
    "    seg_sum[0] += torch.sum(segF)\n",
    "    seg_sum[1] += torch.sum(segP)\n",
    "    seg_sum[2] += torch.sum(segT)\n",
    "mean_s_sum = seg_sum/i"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Loss functions and all other utility functions like functions for saving models, for visualizing images, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import all the Training and evaluate functions to evaluate the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from train_2d import *\n",
    "from train_3d import *\n",
    "from train_ensemble import *\n",
    "from evaluate_2d import *\n",
    "from evaluate_3d import *\n",
    "from evaluate_ensemble import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 14. Experimenting with label smoothing and CEL to deal with noisy ground truth\n",
    "We tried training a network with label smoothing, a technique generally used when the ground truth is noisy or involves a lot of subjectivity. Label smoothing is commonly applied to classification problems but, as far as we know, not to segmentation problems, so we tried it for segmentation. It did not work well: the Dice scores were nowhere close to acceptable levels, so training was stopped midway. As future work, we would like to engineer a way to segment images based on noisy ground truth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "unet_exp_noisy = Unet_dilated_small(9,4,int_var=40,dilated=False).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "410524"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "count_parameters(unet_exp_noisy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "optimizer_unet_exp_noisy = optim.Adam(unet_exp_noisy.parameters(),lr = 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, Phase: train, epoch loss: 0.0069, Dice Score (class 0): 0.4548, Dice Score (class 1): 0.3627,Dice Score (class 2): 0.2732\n",
      "----------\n",
      "Epoch: 0, Phase: validate, epoch loss: 0.0107, Dice Score (class 0): 0.4652, Dice Score (class 1): 0.2242,Dice Score (class 2): 0.2691\n",
      "----------\n",
      "Epoch: 1, Phase: train, epoch loss: 0.0066, Dice Score (class 0): 0.4464, Dice Score (class 1): 0.3631,Dice Score (class 2): 0.2697\n",
      "----------\n",
      "Epoch: 1, Phase: validate, epoch loss: 0.0116, Dice Score (class 0): 0.4738, Dice Score (class 1): 0.1400,Dice Score (class 2): 0.2687\n",
      "----------\n",
      "Epoch: 2, Phase: train, epoch loss: 0.0063, Dice Score (class 0): 0.4417, Dice Score (class 1): 0.3004,Dice Score (class 2): 0.2691\n",
      "----------\n",
      "Epoch: 2, Phase: validate, epoch loss: 0.0083, Dice Score (class 0): 0.4532, Dice Score (class 1): 0.1267,Dice Score (class 2): 0.2386\n",
      "----------\n",
      "Epoch: 3, Phase: train, epoch loss: 0.0058, Dice Score (class 0): 0.4365, Dice Score (class 1): 0.3346,Dice Score (class 2): 0.2603\n",
      "----------\n",
      "Epoch: 3, Phase: validate, epoch loss: 0.0090, Dice Score (class 0): 0.4624, Dice Score (class 1): 0.1272,Dice Score (class 2): 0.2762\n",
      "----------\n",
      "Epoch: 4, Phase: train, epoch loss: 0.0058, Dice Score (class 0): 0.4264, Dice Score (class 1): 0.2908,Dice Score (class 2): 0.2789\n",
      "----------\n",
      "Epoch: 4, Phase: validate, epoch loss: 0.0106, Dice Score (class 0): 0.4784, Dice Score (class 1): 0.0917,Dice Score (class 2): 0.2196\n",
      "----------\n",
      "Epoch: 5, Phase: train, epoch loss: 0.0053, Dice Score (class 0): 0.4341, Dice Score (class 1): 0.3044,Dice Score (class 2): 0.2540\n",
      "----------\n",
      "Epoch: 5, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4828, Dice Score (class 1): 0.1812,Dice Score (class 2): 0.2497\n",
      "----------\n",
      "Epoch: 6, Phase: train, epoch loss: 0.0052, Dice Score (class 0): 0.4285, Dice Score (class 1): 0.2792,Dice Score (class 2): 0.2567\n",
      "----------\n",
      "Epoch: 6, Phase: validate, epoch loss: 0.0081, Dice Score (class 0): 0.4304, Dice Score (class 1): 0.0803,Dice Score (class 2): 0.2250\n",
      "----------\n",
      "Epoch: 7, Phase: train, epoch loss: 0.0049, Dice Score (class 0): 0.4271, Dice Score (class 1): 0.3015,Dice Score (class 2): 0.2550\n",
      "----------\n",
      "Epoch: 7, Phase: validate, epoch loss: 0.0094, Dice Score (class 0): 0.4339, Dice Score (class 1): 0.1199,Dice Score (class 2): 0.2402\n",
      "----------\n",
      "Epoch: 8, Phase: train, epoch loss: 0.0046, Dice Score (class 0): 0.4336, Dice Score (class 1): 0.3101,Dice Score (class 2): 0.2492\n",
      "----------\n",
      "Epoch: 8, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4146, Dice Score (class 1): 0.0966,Dice Score (class 2): 0.2273\n",
      "----------\n",
      "Epoch: 9, Phase: train, epoch loss: 0.0047, Dice Score (class 0): 0.4042, Dice Score (class 1): 0.3042,Dice Score (class 2): 0.2402\n",
      "----------\n",
      "Epoch: 9, Phase: validate, epoch loss: 0.0085, Dice Score (class 0): 0.4546, Dice Score (class 1): 0.0937,Dice Score (class 2): 0.2365\n",
      "----------\n",
      "Epoch: 10, Phase: train, epoch loss: 0.0048, Dice Score (class 0): 0.4017, Dice Score (class 1): 0.2660,Dice Score (class 2): 0.2579\n",
      "----------\n",
      "Epoch: 10, Phase: validate, epoch loss: 0.0093, Dice Score (class 0): 0.3297, Dice Score (class 1): 0.1008,Dice Score (class 2): 0.3376\n",
      "----------\n",
      "Epoch: 11, Phase: train, epoch loss: 0.0052, Dice Score (class 0): 0.3732, Dice Score (class 1): 0.2240,Dice Score (class 2): 0.2164\n",
      "----------\n",
      "Epoch: 11, Phase: validate, epoch loss: 0.0080, Dice Score (class 0): 0.3507, Dice Score (class 1): 0.1192,Dice Score (class 2): 0.1992\n",
      "----------\n",
      "Epoch: 12, Phase: train, epoch loss: 0.0044, Dice Score (class 0): 0.4045, Dice Score (class 1): 0.2663,Dice Score (class 2): 0.2398\n",
      "----------\n",
      "Epoch: 12, Phase: validate, epoch loss: 0.0054, Dice Score (class 0): 0.3737, Dice Score (class 1): 0.1181,Dice Score (class 2): 0.2016\n",
      "----------\n",
      "Epoch: 13, Phase: train, epoch loss: 0.0039, Dice Score (class 0): 0.4121, Dice Score (class 1): 0.2953,Dice Score (class 2): 0.2487\n",
      "----------\n",
      "Epoch: 13, Phase: validate, epoch loss: 0.0062, Dice Score (class 0): 0.4482, Dice Score (class 1): 0.1157,Dice Score (class 2): 0.2612\n",
      "----------\n",
      "Epoch: 14, Phase: train, epoch loss: 0.0037, Dice Score (class 0): 0.4235, Dice Score (class 1): 0.3123,Dice Score (class 2): 0.2620\n",
      "----------\n",
      "Epoch: 14, Phase: validate, epoch loss: 0.0067, Dice Score (class 0): 0.4270, Dice Score (class 1): 0.0954,Dice Score (class 2): 0.2708\n",
      "----------\n",
      "Epoch: 15, Phase: train, epoch loss: 0.0036, Dice Score (class 0): 0.4144, Dice Score (class 1): 0.2824,Dice Score (class 2): 0.2676\n",
      "----------\n",
      "Epoch: 15, Phase: validate, epoch loss: 0.0076, Dice Score (class 0): 0.4503, Dice Score (class 1): 0.1200,Dice Score (class 2): 0.2609\n",
      "----------\n",
      "Epoch: 16, Phase: train, epoch loss: 0.0034, Dice Score (class 0): 0.4183, Dice Score (class 1): 0.2898,Dice Score (class 2): 0.2592\n",
      "----------\n",
      "Epoch: 16, Phase: validate, epoch loss: 0.0061, Dice Score (class 0): 0.4537, Dice Score (class 1): 0.1223,Dice Score (class 2): 0.3014\n",
      "----------\n",
      "Epoch: 17, Phase: train, epoch loss: 0.0035, Dice Score (class 0): 0.4004, Dice Score (class 1): 0.2839,Dice Score (class 2): 0.2320\n",
      "----------\n",
      "Epoch: 17, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.4409, Dice Score (class 1): 0.1723,Dice Score (class 2): 0.3229\n",
      "----------\n",
      "Epoch: 18, Phase: train, epoch loss: 0.0034, Dice Score (class 0): 0.4095, Dice Score (class 1): 0.2638,Dice Score (class 2): 0.2377\n",
      "----------\n",
      "Epoch: 18, Phase: validate, epoch loss: 0.0124, Dice Score (class 0): 0.4778, Dice Score (class 1): 0.0729,Dice Score (class 2): 0.2672\n",
      "----------\n",
      "Epoch: 19, Phase: train, epoch loss: 0.0032, Dice Score (class 0): 0.4091, Dice Score (class 1): 0.3173,Dice Score (class 2): 0.2508\n",
      "----------\n",
      "Epoch: 19, Phase: validate, epoch loss: 0.0087, Dice Score (class 0): 0.4247, Dice Score (class 1): 0.2357,Dice Score (class 2): 0.3106\n",
      "----------\n",
      "Epoch: 20, Phase: train, epoch loss: 0.0030, Dice Score (class 0): 0.4090, Dice Score (class 1): 0.3206,Dice Score (class 2): 0.2621\n",
      "----------\n",
      "Epoch: 20, Phase: validate, epoch loss: 0.0053, Dice Score (class 0): 0.4091, Dice Score (class 1): 0.1603,Dice Score (class 2): 0.2504\n",
      "----------\n",
      "Epoch: 21, Phase: train, epoch loss: 0.0039, Dice Score (class 0): 0.3843, Dice Score (class 1): 0.2829,Dice Score (class 2): 0.2470\n",
      "----------\n",
      "Epoch: 21, Phase: validate, epoch loss: 0.0112, Dice Score (class 0): 0.3247, Dice Score (class 1): 0.0295,Dice Score (class 2): 0.2692\n",
      "----------\n",
      "Epoch: 22, Phase: train, epoch loss: 0.0036, Dice Score (class 0): 0.3750, Dice Score (class 1): 0.1967,Dice Score (class 2): 0.2209\n",
      "----------\n",
      "Epoch: 22, Phase: validate, epoch loss: 0.0075, Dice Score (class 0): 0.4240, Dice Score (class 1): 0.1111,Dice Score (class 2): 0.3162\n",
      "----------\n",
      "Epoch: 23, Phase: train, epoch loss: 0.0030, Dice Score (class 0): 0.3883, Dice Score (class 1): 0.2905,Dice Score (class 2): 0.2495\n",
      "----------\n",
      "Epoch: 23, Phase: validate, epoch loss: 0.0079, Dice Score (class 0): 0.4247, Dice Score (class 1): 0.1192,Dice Score (class 2): 0.2906\n",
      "----------\n",
      "Epoch: 24, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.4021, Dice Score (class 1): 0.3212,Dice Score (class 2): 0.2533\n",
      "----------\n",
      "Epoch: 24, Phase: validate, epoch loss: 0.0067, Dice Score (class 0): 0.4512, Dice Score (class 1): 0.2062,Dice Score (class 2): 0.3409\n",
      "----------\n",
      "Epoch: 25, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.4044, Dice Score (class 1): 0.3422,Dice Score (class 2): 0.2634\n",
      "----------\n",
      "Epoch: 25, Phase: validate, epoch loss: 0.0094, Dice Score (class 0): 0.4467, Dice Score (class 1): 0.3159,Dice Score (class 2): 0.3170\n",
      "----------\n",
      "Epoch: 26, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.4007, Dice Score (class 1): 0.3061,Dice Score (class 2): 0.2244\n",
      "----------\n",
      "Epoch: 26, Phase: validate, epoch loss: 0.0062, Dice Score (class 0): 0.4598, Dice Score (class 1): 0.0884,Dice Score (class 2): 0.1960\n",
      "----------\n",
      "Epoch: 27, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.3824, Dice Score (class 1): 0.2838,Dice Score (class 2): 0.2146\n",
      "----------\n",
      "Epoch: 27, Phase: validate, epoch loss: 0.0100, Dice Score (class 0): 0.4430, Dice Score (class 1): 0.1496,Dice Score (class 2): 0.2968\n",
      "----------\n",
      "Epoch: 28, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.3903, Dice Score (class 1): 0.3193,Dice Score (class 2): 0.2287\n",
      "----------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 28, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.3152, Dice Score (class 1): 0.1829,Dice Score (class 2): 0.2807\n",
      "----------\n",
      "Epoch: 29, Phase: train, epoch loss: 0.0028, Dice Score (class 0): 0.3628, Dice Score (class 1): 0.2731,Dice Score (class 2): 0.2147\n",
      "----------\n",
      "Epoch: 29, Phase: validate, epoch loss: 0.0089, Dice Score (class 0): 0.4576, Dice Score (class 1): 0.1791,Dice Score (class 2): 0.2617\n",
      "----------\n",
      "Epoch: 30, Phase: train, epoch loss: 0.0025, Dice Score (class 0): 0.3856, Dice Score (class 1): 0.3332,Dice Score (class 2): 0.2106\n",
      "----------\n",
      "Epoch: 30, Phase: validate, epoch loss: 0.0088, Dice Score (class 0): 0.4605, Dice Score (class 1): 0.2360,Dice Score (class 2): 0.2981\n",
      "----------\n",
      "Epoch: 31, Phase: train, epoch loss: 0.0022, Dice Score (class 0): 0.4015, Dice Score (class 1): 0.3838,Dice Score (class 2): 0.2237\n",
      "----------\n",
      "Epoch: 31, Phase: validate, epoch loss: 0.0103, Dice Score (class 0): 0.4688, Dice Score (class 1): 0.4179,Dice Score (class 2): 0.3486\n",
      "----------\n",
      "Epoch: 32, Phase: train, epoch loss: 0.0022, Dice Score (class 0): 0.3904, Dice Score (class 1): 0.3807,Dice Score (class 2): 0.2404\n",
      "----------\n",
      "Epoch: 32, Phase: validate, epoch loss: 0.0085, Dice Score (class 0): 0.4480, Dice Score (class 1): 0.3601,Dice Score (class 2): 0.2842\n",
      "----------\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-80-d560073d377c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m                                                      \u001b[0mdataloader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdata_sizes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'new_data_unet_exp_noisy_1'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m                                                      \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m \u001b[0;36m50\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m                                                      dice_loss = dice_loss_3,noisy_labels = True)\n\u001b[0m",
      "\u001b[0;32m<ipython-input-24-c237342f9bb3>\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, optimizer, dataloader, data_sizes, batch_size, name, num_epochs, verbose, dice_loss, noisy_labels)\u001b[0m\n\u001b[1;32m     24\u001b[0m                 \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m             \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdataloader\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     27\u001b[0m                 \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     28\u001b[0m                 \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegP\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msegT\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/share/apps/pytorch/0.2.0_3/python3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    177\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# same-process loading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    178\u001b[0m             \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_iter\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m             \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    180\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    181\u001b[0m                 \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpin_memory_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/share/apps/pytorch/0.2.0_3/python3.6/lib/python3.6/site-packages/torch/utils/data/dataloader.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    177\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_workers\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# same-process loading\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    178\u001b[0m             \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample_iter\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# may raise StopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m             \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollate_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    180\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpin_memory\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    181\u001b[0m                 \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpin_memory_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-3-676504d23560>\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m     47\u001b[0m                 \u001b[0mfa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfliplr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfa\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotation\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m                 \u001b[0msegment_T\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_T\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     50\u001b[0m                 \u001b[0msegment_F\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_F\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m                 \u001b[0msegment_P\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrotate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msegment_P\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mangle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/pyenv/py3.6.3/lib/python3.6/site-packages/skimage/transform/_warps.py\u001b[0m in \u001b[0;36mrotate\u001b[0;34m(image, angle, resize, center, order, mode, cval, clip, preserve_range)\u001b[0m\n\u001b[1;32m    298\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    299\u001b[0m     return warp(image, tform, output_shape=output_shape, order=order,\n\u001b[0;32m--> 300\u001b[0;31m                 mode=mode, cval=cval, clip=clip, preserve_range=preserve_range)\n\u001b[0m\u001b[1;32m    301\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    302\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/pyenv/py3.6.3/lib/python3.6/site-packages/skimage/transform/_warps.py\u001b[0m in \u001b[0;36mwarp\u001b[0;34m(image, inverse_map, map_args, output_shape, order, mode, cval, clip, preserve_range)\u001b[0m\n\u001b[1;32m    767\u001b[0m                 warped = _warp_fast(image, matrix,\n\u001b[1;32m    768\u001b[0m                                  \u001b[0moutput_shape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_shape\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 769\u001b[0;31m                                  order=order, mode=mode, cval=cval)\n\u001b[0m\u001b[1;32m    770\u001b[0m             \u001b[0;32melif\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    771\u001b[0m                 \u001b[0mdims\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32mskimage/transform/_warps_cy.pyx\u001b[0m in \u001b[0;36mskimage.transform._warps_cy._warp_fast\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32m/share/apps/python3/3.6.3/intel/lib/python3.6/site-packages/numpy-1.13.3-py3.6-linux-x86_64.egg/numpy/core/numeric.py\u001b[0m in \u001b[0;36masarray\u001b[0;34m(a, dtype, order)\u001b[0m\n\u001b[1;32m    461\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    462\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 463\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0morder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    464\u001b[0m     \"\"\"Convert the input to an array.\n\u001b[1;32m    465\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "unet_exp_noisy, loss_hist_unet_exp, dc_hist_0_unet_exp, \\\n",
    "dc_hist_1_unet_exp, dc_hist_2_unet_exp = train_model(unet_exp_noisy, optimizer_unet_exp_noisy,\n",
    "                                                     dataloader,data_sizes,5,'new_data_unet_exp_noisy_1',\n",
    "                                                     num_epochs= 50, verbose = True, \n",
    "                                                     dice_loss = dice_loss_3,noisy_labels = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# torch.save(model_gen_dilated_l4_n2_new_data_dp,'new_data_dilated_net_l4_n2_nd_dp_1')\n",
    "pickle.dump(loss_hist_unet_exp, open('loss_hist_new_data_unet_exp_noisy_1','wb'))\n",
    "pickle.dump(dc_hist_0_unet_exp, open('dc_hist_0_new_data_unet_exp_noisy_1','wb'))\n",
    "pickle.dump(dc_hist_1_unet_exp, open('dc_hist_1_new_data_unet_exp_noisy_1','wb'))\n",
    "pickle.dump(dc_hist_2_unet_exp, open('dc_hist_2_new_data_unet_exp_noisy_1','wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "plot_hist(loss_hist_unet_exp,'Loss')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "evaluate(unet_exp_noisy, dataloader, data_sizes, 5, 'validate', dice_loss=dice_loss_3, noisy_labels = True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}